import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Scanner;
import llama.*;
import static java.lang.foreign.ValueLayout.*;
import static llama.llama_h.*;

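// Research directions:
// 1. Exporting and analyzing the token vocabulary (Qwen3-8B contains 151936 token types; 69k+ pure-English tokens; 25k+ tokens containing Chinese characters; 8k+ single Chinese characters; 16k+ multi-character Chinese words)
// 2. Tokenizing the input text
// 3. Exporting and analyzing the logits of each output token
// 4. Sampling strategies (MinP: only consider tokens whose logit is within ln(p) of logitMax; Temp: divide each logit by temp; SoftMax: e^(logit-logitMax) gives a probability in (0,1], which is then normalized linearly)
// 5. Handling a context that exceeds the contextSize limit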
public class Llama {
	static void main() throws IOException {
		final int gpuLayers = 99;
		final int contextSize = 4096;
		final int threads = Runtime.getRuntime().availableProcessors() >> 1;

		final var modelPath = "d:/models/Qwen3-8B-Q4_K_M.gguf";
		final var prompt = "You are a helpful assistant.";
		final var chatFmtPmt = "<|im_start|>system\n%s";
		final var chatFmtUser = "<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";

//		final var modelPath = "d:/models/Seed-X-Instruct-7B.Q4_K_M.gguf";
//		final var prompt = "Translate the following English sentence into Chinese:";
//		final var chatFmtPmt = "%s\n";
//		final var chatFmtUser = "%s <zh><s>";

		final MemorySegment model, ctx, mem;
		try (var arena = Arena.ofConfined()) {
			ggml_backend_load_all();
			var modelParams = llama_model_default_params(arena);
			llama_model_params.n_gpu_layers(modelParams, gpuLayers);
			model = llama_model_load_from_file(arena.allocateFrom(modelPath, StandardCharsets.UTF_8), modelParams);
			if (model.equals(MemorySegment.NULL))
				throw new AssertionError("llama_model_load_from_file failed: " + modelPath);
			var ctxParams = llama_context_default_params(arena);
			llama_context_params.n_ctx(ctxParams, contextSize);
			llama_context_params.n_batch(ctxParams, contextSize);
			llama_context_params.n_threads(ctxParams, threads);
			llama_context_params.n_threads_batch(ctxParams, threads);
			ctx = llama_init_from_model(model, ctxParams);
			if (ctx.equals(MemorySegment.NULL))
				throw new AssertionError("llama_init_from_model failed: " + modelPath);
			mem = llama_get_memory(ctx);
			if (mem.equals(MemorySegment.NULL))
				throw new AssertionError("llama_get_memory failed: " + modelPath);
		}
		// exportVocab(model, modelPath + ".vocab.txt");
		final var tpl = llama_model_chat_template(model, MemorySegment.NULL);
		if (tpl != null && tpl.address() != 0)
			System.out.println(tpl.getString(0, StandardCharsets.UTF_8));
		System.out.println("= " + prompt);
		try (final var scanner = new Scanner(System.in)) {
			final var vocab = llama_model_get_vocab(model);
			final int vocabNum = llama_vocab_n_tokens(vocab);
			final var arena = Arena.ofAuto();
			final var batch = llama_batch.allocate(arena);
			final var strBytes = new byte[256];
			var strBuf = arena.allocate(4096);
			var tokenBuf = arena.allocate(JAVA_INT, 4096);
			int histoNum = 0;
			for (var system = String.format(chatFmtPmt, prompt); ; system = null) {
				System.out.print("> ");
				var user = scanner.nextLine().trim();
				if (user.isEmpty())
					break;
				user = String.format(chatFmtUser, user);
				if (system != null)
					user = system + user;
				var userBytes = user.getBytes(StandardCharsets.UTF_8);
				int strLen = userBytes.length;
				if (strBuf.byteSize() < strLen)
					strBuf = arena.allocate(strLen);
				MemorySegment.copy(userBytes, 0, strBuf, JAVA_BYTE, 0, strLen);
				int tokenNum = -llama_tokenize(vocab, strBuf, strLen, MemorySegment.NULL, 0, true, true);
				if (tokenBuf.byteSize() >> 2 < tokenNum)
					tokenBuf = arena.allocate(JAVA_INT, tokenNum);
				int r = llama_tokenize(vocab, strBuf, strLen, tokenBuf, tokenNum, true, true);
				if (r != tokenNum)
					throw new AssertionError("llama_tokenize failed: " + r + " != " + tokenNum);
				exportToken(userBytes, vocab, tokenBuf, tokenNum, "llama_tokens.log");
				llama_batch.token(batch, tokenBuf);
				System.out.print("< ");
				for (int bufLen = 0; ; ) {
					if (histoNum + tokenNum >= contextSize) {
						int keep = 0; // maybe prompt token count + 1(for <|im_start|>)
						int discard = Math.max((histoNum - keep) >> 1, histoNum + tokenNum - contextSize + 1);
						llama_memory_seq_rm(mem, 0, keep, keep + discard);
						llama_memory_seq_add(mem, 0, keep + discard, histoNum, -discard); // move pos
						histoNum -= discard;
						System.err.println("[discard " + discard + "]");
					}
					histoNum += tokenNum;
					llama_batch.n_tokens(batch, tokenNum);
					r = llama_decode(ctx, batch);
					if (r != 0)
						throw new AssertionError("llama_decode failed: " + r);
					var logits = llama_get_logits_ith(ctx, -1).reinterpret(vocabNum * JAVA_FLOAT.byteSize());
					int selectToken = 0;
					var max = -Float.MAX_VALUE;
					for (int i = 0; i < vocabNum; i++) {
						var v = logits.get(JAVA_FLOAT, i * JAVA_FLOAT.byteSize());
						if (max < v) {
							max = v;
							selectToken = i;
						}
					}
					exportLogits(vocab, logits, 100, "llama_logits.log");
					if (llama_vocab_is_eog(vocab, selectToken)) {
						if (bufLen > 0) {
							for (int i = 0; i < bufLen; i++) {
								System.out.print(i == 0 ? '[' : ' ');
								System.out.printf("%02X", strBytes[i] & 0xff);
							}
							System.out.print(']');
						}
						break;
					}
					tokenBuf.set(JAVA_INT, 0, selectToken);
					tokenNum = 1;
					r = llama_token_to_piece(vocab, selectToken, strBuf, (int)strBuf.byteSize(), 0, true);
					if (r <= 0 || r > strBytes.length - bufLen)
						System.out.print("<" + selectToken + "," + r + ">");
					else {
						MemorySegment.copy(strBuf, JAVA_BYTE, 0, strBytes, bufLen, r);
						bufLen += r;
						if (checkUtf8(strBytes, 0, bufLen) == bufLen) {
							System.out.print(new String(strBytes, 0, bufLen, StandardCharsets.UTF_8));
							bufLen = 0;
						}
					}
				}
				System.out.println();
			}
		}
		llama_free(ctx);
		llama_model_free(model);
		if (exportTokenStream != null)
			exportTokenStream.close();
		if (exportLogitsStream != null)
			exportLogitsStream.close();
		System.out.println("= END");
	}

	public static int checkUtf8(byte[] buf, int i, int e) {
		while (i < e) {
			int b = buf[i] & 0xff;
			if (b < 0xc0) // 0xxx xxxx | 10xx xxxx
				i++;
			else if (b < 0xe0) { // 110x xxxx
				i += 2;
				if (i > e)
					return i - 2;
			} else if (b < 0xf0) { // 1110 xxxx
				i += 3;
				if (i > e)
					return i - 3;
			} else if (b < 0xf8) { // 1111 0xxx
				i += 4;
				if (i > e)
					return i - 4;
			} else
				break;
		}
		return i;
	}

	public static void exportVocab(MemorySegment model, String fileName) throws IOException {
		final var vocab = llama_model_get_vocab(model);
		final int vocabNum = llama_vocab_n_tokens(vocab);
		try (var arena = Arena.ofConfined(); var fos = new BufferedOutputStream(new FileOutputStream(fileName))) {
			int bufLen = 256;
			int maxLen = Integer.MIN_VALUE;
			var buf = arena.allocate(bufLen);
			var bytes = new byte[bufLen];
			for (int i = 0; i < vocabNum; i++) {
				fos.write(String.valueOf(i).getBytes(StandardCharsets.ISO_8859_1));
				fos.write(':');
				int r = llama_token_to_piece(vocab, i, buf, bufLen, 0, true);
				if (maxLen < r)
					maxLen = r;
				if (r <= 0 || r > bufLen)
					fos.write(("<" + r + ">").getBytes(StandardCharsets.ISO_8859_1));
				else {
					MemorySegment.copy(buf, JAVA_BYTE, 0, bytes, 0, r);
					var t = new String(bytes, 0, r, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8);
					if (r == 1 && bytes[0] >= 0x20 && bytes[0] < 0x7f
							|| r > 1 && t.length == r && Arrays.equals(t, 0, r, bytes, 0, r)) {
						fos.write(' ');
						fos.write(bytes, 0, r);
					} else {
						for (int j = 0; j < r; j++) {
							fos.write(j == 0 ? '[' : ' ');
							fos.write(String.format("%02X", bytes[j] & 0xff).getBytes(StandardCharsets.ISO_8859_1));
						}
						fos.write(']');
					}
				}
				fos.write('\n');
			}
			fos.write(("maxLen=" + maxLen + "\n").getBytes(StandardCharsets.ISO_8859_1));
			final var tpl = llama_model_chat_template(model, MemorySegment.NULL);
			if (tpl != null && tpl.address() != 0)
				fos.write(tpl.getString(0, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8));
		}
	}

	private static OutputStream exportTokenStream;

	public static void exportToken(byte[] str, MemorySegment vocab, MemorySegment tokenBuf, int tokenNum,
								   String logFileName) throws IOException {
		try (var arena = Arena.ofConfined()) {
			var fos = exportTokenStream;
			if (fos == null)
				exportTokenStream = fos = new BufferedOutputStream(new FileOutputStream(logFileName, true));
			fos.write(str);
			fos.write("\n------------------------\n".getBytes(StandardCharsets.ISO_8859_1));
			int bufLen = 256;
			var buf = arena.allocate(bufLen);
			var bytes = new byte[bufLen];
			for (int i = 0; i < tokenNum; i++) {
				int token = tokenBuf.get(JAVA_INT, i * JAVA_INT.byteSize());
				fos.write(("(" + token + ")").getBytes(StandardCharsets.ISO_8859_1));
				int r = llama_token_to_piece(vocab, token, buf, bufLen, 0, true);
				if (r <= 0 || r > bufLen)
					fos.write(("<" + r + ">").getBytes(StandardCharsets.ISO_8859_1));
				else {
					MemorySegment.copy(buf, JAVA_BYTE, 0, bytes, 0, r);
					var t = new String(bytes, 0, r, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8);
					if (r == 1 && (bytes[0] >= 0x20 && bytes[0] < 0x7f || bytes[0] == '\n')
							|| r > 1 && t.length == r && Arrays.equals(t, 0, r, bytes, 0, r)) {
						fos.write(bytes, 0, r);
					} else {
						for (int j = 0; j < r; j++) {
							fos.write(j == 0 ? '[' : ' ');
							fos.write(String.format("%02X", bytes[j] & 0xff).getBytes(StandardCharsets.ISO_8859_1));
						}
						fos.write(']');
					}
				}
			}
			fos.write("\n========================\n".getBytes(StandardCharsets.ISO_8859_1));
			fos.flush();
		}
	}

	private static OutputStream exportLogitsStream;

	public static void exportLogits(MemorySegment vocab, MemorySegment logits, int n,
									String logFileName) throws IOException {
		var elems = new double[(int)(logits.byteSize() / JAVA_FLOAT.byteSize())];
		for (int i = 0, e = elems.length; i < e; i++) {
			elems[i] = Double.longBitsToDouble(((long)Float.floatToRawIntBits(
					logits.get(JAVA_FLOAT, i * JAVA_FLOAT.byteSize())) << 32) | (~i & 0xffff_ffffL));
		}
		Arrays.sort(elems);
		try (var arena = Arena.ofConfined()) {
			var fos = exportLogitsStream;
			if (fos == null)
				exportLogitsStream = fos = new BufferedOutputStream(new FileOutputStream(logFileName, true));
			fos.write("""
					------------------------
					idx token logit     word
					------------------------
					""".getBytes(StandardCharsets.ISO_8859_1));
			int bufLen = 256;
			var buf = arena.allocate(bufLen);
			var bytes = new byte[bufLen];
			for (int i = elems.length, e = Math.max(i - n, 0); e <= --i; ) {
				var elem = Double.doubleToRawLongBits(elems[i]);
				int token = ~(int)elem;
				var logit = Float.intBitsToFloat((int)(elem >> 32));
				fos.write(String.format("%2d:%6d %9.6f", elems.length - i - 1, token, logit)
						.getBytes(StandardCharsets.ISO_8859_1));
				int r = llama_token_to_piece(vocab, token, buf, bufLen, 0, true);
				if (r <= 0 || r > bufLen)
					fos.write(("<" + r + ">").getBytes(StandardCharsets.ISO_8859_1));
				else {
					MemorySegment.copy(buf, JAVA_BYTE, 0, bytes, 0, r);
					var t = new String(bytes, 0, r, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8);
					if (r == 1 && bytes[0] >= 0x20 && bytes[0] < 0x7f
							|| r > 1 && t.length == r && Arrays.equals(t, 0, r, bytes, 0, r)) {
						fos.write(' ');
						fos.write(bytes, 0, r);
					} else {
						for (int j = 0; j < r; j++) {
							fos.write(j == 0 ? '[' : ' ');
							fos.write(String.format("%02X", bytes[j] & 0xff).getBytes(StandardCharsets.ISO_8859_1));
						}
						fos.write(']');
					}
				}
				fos.write('\n');
			}
			fos.write(String.format("mid:%15.6f\n", Float.intBitsToFloat((int)(Double.doubleToRawLongBits
					(elems[elems.length >> 1]) >> 32))).getBytes(StandardCharsets.ISO_8859_1));
			fos.write(String.format("last:%14.6f\n", Float.intBitsToFloat((int)(Double.doubleToRawLongBits
					(elems[0]) >> 32))).getBytes(StandardCharsets.ISO_8859_1));
			fos.flush();
		}
	}
}
